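# experiment2 = baseline + music.
# NOTE(review): this module registers 'experiment3'; confirm whether these notes
# (and the result below) describe experiment3 or were carried over from experiment2.
# Backbone: reuses the backbone already trained by the baseline.
# Result (5-way, 5-shot, 15-query, 600 tasks): mean acc 0.8228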


from lib.algorithms import ALGORITHMS
from .experiment3_net import Experiment3
from .experiment3_loss import Experiment3Loss
from .experiment3_metric import Experiment3Acc
from lib.sampler import NormalSampler
from lib.sampler import EpisodeSampler


@ALGORITHMS.register('experiment3')
def experiment3(cfg):
    """Build the components of the 'experiment3' algorithm.

    Args:
        cfg: Configuration object, forwarded unchanged to the net, loss and
            metric constructors.

    Returns:
        A dict with the following keys:
            'net', 'loss', 'metric': instances of ``Experiment3``,
                ``Experiment3Loss`` and ``Experiment3Acc`` built from ``cfg``.
            'train_sampler', 'validate_sampler', 'test_sampler': sampler
                *classes* (not instances) -- ``NormalSampler`` for training and
                validation, ``EpisodeSampler`` for testing. Presumably the
                caller instantiates them later; confirm against the caller.
    """
    return {
        'net': Experiment3(cfg),
        'loss': Experiment3Loss(cfg),
        'metric': Experiment3Acc(cfg),
        # Samplers are passed as classes; they are not constructed here.
        'train_sampler': NormalSampler,
        'validate_sampler': NormalSampler,
        'test_sampler': EpisodeSampler,
    }


# Backward-compatible alias for the previous, misspelled function name, in case
# other modules import ``expirement2`` from here directly.
expirement2 = experiment3